Tril ================= 返回输入张量的下三角部分,其余部分设置为零 对于输入形状为 :math:`(*, H, W)` 的张量,返回每个 :math:`H \times W` 矩阵的下三角部分。 .. math:: \text{output}[i,j] = \begin{cases} \text{input}[i,j], & \text{if } j \leq i + k \\ 0, & \text{otherwise} \end{cases} 其中 :math:`k` 为对角线偏移量: - :math:`k = 0`:主对角线(默认) - :math:`k > 0`:主对角线上方第 k 条对角线 - :math:`k < 0`:主对角线下方第 k 条对角线 输入: - **src** - 输入数据地址。 - **k** - 对角线偏移量,默认为 0。 - **height** - 矩阵高度。 - **width** - 矩阵宽度。 - **out_elems** - 输出矩阵的数量(批量大小)。 - **core_mask(int, 可选)** - 核掩码(仅适用于共享存储版本)。 输出: - **dst** - 计算结果地址。 支持平台: ``FT78NE`` ``MT7004`` .. note:: - FT78NE 支持int8, int16, int32, fp32, fp64, cplx64, cplx128 - MT7004 支持fp16, fp32, int16, int32, cplx64 **共享存储版本:** .. c:function:: void i8_tril_s(int8_t* dst, int8_t* src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void i16_tril_s(int16_t* dst, int16_t* src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void i32_tril_s(int32_t* dst, int32_t* src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void hp_tril_s(half* dst, half* src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void fp_tril_s(float* dst, float* src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void dp_tril_s(double* dst, double* src, int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void c64_tril_s(float (*dst)[2], float (*src)[2], int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void c128_tril_s(double (*dst)[2], double (*src)[2], int core_mask, int64_t k, int64_t height, int64_t width, int64_t out_elems) **C调用示例:** .. code-block:: c :linenos: :emphasize-lines: 13 //FT78NE示例 #include #include int main(int argc, char* argv[]) { float *input = (float *)0xA0000000; //input在DDR空间 float *output = (float *)0xB0000000; //output在DDR空间 int64_t k = 0; // 主对角线 int64_t height = 4; // 矩阵高度 int64_t width = 4; // 矩阵宽度 int64_t out_elems = 1; // 矩阵数量 int core_mask = 0xff; fp_tril_s(output, input, core_mask, k, height, width, out_elems); return 0; } **私有存储版本:** .. c:function:: void i8_tril_p(int8_t* dst, int8_t* src, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void i16_tril_p(int16_t* dst, int16_t* src, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void i32_tril_p(int32_t* dst, int32_t* src, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void hp_tril_p(half* dst, half* src, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void fp_tril_p(float* dst, float* src, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void dp_tril_p(double* dst, double* src, int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void c64_tril_p(float (*dst)[2], float (*src)[2], int64_t k, int64_t height, int64_t width, int64_t out_elems) .. c:function:: void c128_tril_p(double (*dst)[2], double (*src)[2], int64_t k, int64_t height, int64_t width, int64_t out_elems) **C调用示例:** .. code-block:: c :linenos: :emphasize-lines: 12 //FT78NE示例 #include #include int main(int argc, char* argv[]) { float *input = (float *)0x10810000; //input在L2空间 float *output = (float *)0x10850000; //output在L2空间 int64_t k = 0; // 主对角线 int64_t height = 4; // 矩阵高度 int64_t width = 4; // 矩阵宽度 int64_t out_elems = 1; // 矩阵数量 fp_tril_p(output, input, k, height, width, out_elems); return 0; } .. note:: **示例说明:** 对于 4x4 矩阵,k=0 时的下三角输出示例: .. code-block:: text 输入矩阵: 输出矩阵: 1 2 3 4 1 0 0 0 5 6 7 8 => 5 6 0 0 9 10 11 12 9 10 11 0 13 14 15 16 13 14 15 16 k=1 时(上移一条对角线,包含更多元素): .. code-block:: text 输入矩阵: 输出矩阵: 1 2 3 4 1 2 0 0 5 6 7 8 => 5 6 7 0 9 10 11 12 9 10 11 12 13 14 15 16 13 14 15 16 k=-1 时(下移一条对角线,更严格的下三角): .. code-block:: text 输入矩阵: 输出矩阵: 1 2 3 4 0 0 0 0 5 6 7 8 => 5 0 0 0 9 10 11 12 9 10 0 0 13 14 15 16 13 14 15 0